{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# task5：多类型情感分析"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在之前的所有学习中，我们的数据集对于情感的分析只有两个分类：正面或负面。当我们只有两个类时，我们的输出可以是单个标量，范围在 0 和 1 之间，表示示例属于哪个类。当我们有 2 个以上的类时，我们的输出必须是一个 $C$ 维向量，其中 $C$ 是类的数量。\n",
    "\n",
    "在本次学习中，我们将对具有 6 个类的数据集执行分类。请注意，该数据集实际上并不是情感分析数据集，而是问题数据集，任务是对问题所属的类别进行分类。但是，本次学习中涵盖的所有内容都适用于任何包含属于 $C$ 类之一的输入序列的示例的数据集。\n",
    "\n",
    "下面，我们设置字段并加载数据集，与之前不同的是：\n",
    "\n",
    "第一，我们不需要在 `LABEL` 字段中设置 `dtype`。在处理多类问题时，PyTorch 期望标签被数字化为`LongTensor`。\n",
    "\n",
    "第二，这次我们使用的是`TREC`数据集而不是`IMDB`数据集。 `fine_grained` 参数允许我们使用细粒度标签（其中有50个类）或不使用（在这种情况下它们将是6个类）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\ProgramData\\Anaconda3\\lib\\site-packages\\spacy\\util.py:740: UserWarning: [W094] Model 'en_core_web_sm' (2.2.0) specifies an under-constrained spaCy version requirement: >=2.2.0. This can lead to compatibility problems with older versions, or as new spaCy versions are released, because the model may say it's compatible when it's not. Consider changing the \"spacy_version\" in your meta.json to a version range, with a lower and upper pin. For example: >=3.1.2,<3.2.0\n",
      "  warnings.warn(warn_msg)\n"
     ]
    },
    {
     "ename": "OSError",
     "evalue": "[E053] Could not read config.cfg from D:\\ProgramData\\Anaconda3\\lib\\site-packages\\en_core_web_sm\\en_core_web_sm-2.2.0\\config.cfg",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mOSError\u001b[0m                                   Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-1-78a725e4bc2e>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      9\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdeterministic\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m \u001b[0mTEXT\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mField\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtokenize\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'spacy'\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtokenizer_language\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;34m'en_core_web_sm'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     12\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     13\u001b[0m \u001b[0mLABEL\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mLabelField\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchtext\\legacy\\data\\field.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, sequential, use_vocab, init_token, eos_token, fix_length, dtype, preprocessing, postprocessing, lower, tokenize, tokenizer_language, include_lengths, batch_first, pad_token, unk_token, pad_first, truncate_first, stop_words, is_target)\u001b[0m\n\u001b[0;32m    159\u001b[0m         \u001b[1;31m# in case the tokenizer isn't picklable (e.g. spacy)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    160\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtokenizer_args\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtokenizer_language\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 161\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtokenize\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_tokenizer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtokenize\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtokenizer_language\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    162\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minclude_lengths\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minclude_lengths\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    163\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbatch_first\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbatch_first\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchtext\\data\\utils.py\u001b[0m in \u001b[0;36mget_tokenizer\u001b[1;34m(tokenizer, language)\u001b[0m\n\u001b[0;32m    113\u001b[0m             \u001b[1;32mimport\u001b[0m \u001b[0mspacy\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    114\u001b[0m             \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 115\u001b[1;33m                 \u001b[0mspacy\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mspacy\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlanguage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    116\u001b[0m             \u001b[1;32mexcept\u001b[0m \u001b[0mIOError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    117\u001b[0m                 \u001b[1;31m# Model shortcuts no longer work in spaCy 3.0+, try using fullnames\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\spacy\\__init__.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(name, vocab, disable, exclude, config)\u001b[0m\n\u001b[0;32m     49\u001b[0m     \u001b[0mRETURNS\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mLanguage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mThe\u001b[0m \u001b[0mloaded\u001b[0m \u001b[0mnlp\u001b[0m \u001b[0mobject\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     50\u001b[0m     \"\"\"\n\u001b[1;32m---> 51\u001b[1;33m     return util.load_model(\n\u001b[0m\u001b[0;32m     52\u001b[0m         \u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mvocab\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdisable\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     53\u001b[0m     )\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\spacy\\util.py\u001b[0m in \u001b[0;36mload_model\u001b[1;34m(name, vocab, disable, exclude, config)\u001b[0m\n\u001b[0;32m    319\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mget_lang_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreplace\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"blank:\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    320\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mis_package\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m  \u001b[1;31m# installed as package\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 321\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mload_model_from_package\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    322\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mPath\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m  \u001b[1;31m# path to model data directory\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    323\u001b[0m             \u001b[1;32mreturn\u001b[0m \u001b[0mload_model_from_path\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mPath\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\spacy\\util.py\u001b[0m in \u001b[0;36mload_model_from_package\u001b[1;34m(name, vocab, disable, exclude, config)\u001b[0m\n\u001b[0;32m    352\u001b[0m     \"\"\"\n\u001b[0;32m    353\u001b[0m     \u001b[0mcls\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 354\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mcls\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvocab\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mvocab\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdisable\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    355\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    356\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\en_core_web_sm\\__init__.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(**overrides)\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     11\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m**\u001b[0m\u001b[0moverrides\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mload_model_from_init_py\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m__file__\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0moverrides\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\spacy\\util.py\u001b[0m in \u001b[0;36mload_model_from_init_py\u001b[1;34m(init_file, vocab, disable, exclude, config)\u001b[0m\n\u001b[0;32m    512\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mmodel_path\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    513\u001b[0m         \u001b[1;32mraise\u001b[0m \u001b[0mIOError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mErrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mE052\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdata_path\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 514\u001b[1;33m     return load_model_from_path(\n\u001b[0m\u001b[0;32m    515\u001b[0m         \u001b[0mdata_path\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    516\u001b[0m         \u001b[0mvocab\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mvocab\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\spacy\\util.py\u001b[0m in \u001b[0;36mload_model_from_path\u001b[1;34m(model_path, meta, vocab, disable, exclude, config)\u001b[0m\n\u001b[0;32m    386\u001b[0m     \u001b[0mconfig_path\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel_path\u001b[0m \u001b[1;33m/\u001b[0m \u001b[1;34m\"config.cfg\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    387\u001b[0m     \u001b[0moverrides\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdict_to_dot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 388\u001b[1;33m     \u001b[0mconfig\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mload_config\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig_path\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moverrides\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moverrides\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    389\u001b[0m     \u001b[0mnlp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mload_model_from_config\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mvocab\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdisable\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdisable\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    390\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mnlp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfrom_disk\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel_path\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mexclude\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mexclude\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moverrides\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moverrides\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\spacy\\util.py\u001b[0m in \u001b[0;36mload_config\u001b[1;34m(path, overrides, interpolate)\u001b[0m\n\u001b[0;32m    543\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    544\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mconfig_path\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mconfig_path\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexists\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mconfig_path\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_file\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 545\u001b[1;33m             \u001b[1;32mraise\u001b[0m \u001b[0mIOError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mErrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mE053\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mconfig_path\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"config.cfg\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    546\u001b[0m         return config.from_disk(\n\u001b[0;32m    547\u001b[0m             \u001b[0mconfig_path\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moverrides\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0moverrides\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minterpolate\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minterpolate\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mOSError\u001b[0m: [E053] Could not read config.cfg from D:\\ProgramData\\Anaconda3\\lib\\site-packages\\en_core_web_sm\\en_core_web_sm-2.2.0\\config.cfg"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torchtext.legacy import data\n",
    "from torchtext.legacy import datasets\n",
    "import random\n",
    "\n",
    "SEED = 1234\n",
    "\n",
    "# Fix the PyTorch seed and ask cuDNN for deterministic kernels.\n",
    "torch.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "TEXT = data.Field(tokenize = 'spacy', tokenizer_language = 'en_core_web_sm')\n",
    "\n",
    "# No dtype is passed for the label field in the multi-class setting.\n",
    "LABEL = data.LabelField()\n",
    "\n",
    "train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=False)\n",
    "\n",
    "# random.seed() returns None, so seed Python's RNG first and then pass its\n",
    "# state to split() explicitly rather than handing over random_state = None.\n",
    "random.seed(SEED)\n",
    "train_data, valid_data = train_data.split(random_state = random.getstate())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面我们看一个训练集的示例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ['What', 'is', 'a', 'Cartesian', 'Diver', '?'], 'label': 'DESC'}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vars(train_data[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们将构建词汇表。 由于这个数据集很小（只有约 3800 个训练样本），它的词汇量也非常小（约 7500 个不同单词，即one-hot向量为7500维），这意味着我们其实不需要像以前一样在词汇表上设置 `max_size`：下面的代码虽然仍然传入了 `MAX_VOCAB_SIZE = 25_000`，但这个上限远大于实际词汇量，不会截断任何单词。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_VOCAB_SIZE = 25_000\n",
    "\n",
    "# Vocabulary for the question text, with 100-d GloVe vectors attached.\n",
    "TEXT.build_vocab(\n",
    "    train_data,\n",
    "    max_size = MAX_VOCAB_SIZE,\n",
    "    vectors = \"glove.6B.100d\",\n",
    "    unk_init = torch.Tensor.normal_,\n",
    ")\n",
    "\n",
    "# Vocabulary for the class labels.\n",
    "LABEL.build_vocab(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们可以检查标签。\n",
    "\n",
    "6 个标签（对于非细粒度情况）对应于数据集中的 6 类问题：\n",
    "- `HUM`：关于人类的问题\n",
    "- `ENTY`：关于实体的问题\n",
    "- `DESC`：关于要求提供描述的问题\n",
    "- `NUM`：关于答案为数字的问题\n",
    "- `LOC`：关于答案是位置的问题\n",
    "- `ABBR`：关于询问缩写的问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(None, {'HUM': 0, 'ENTY': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})\n"
     ]
    }
   ],
   "source": [
    "print(LABEL.vocab.stoi)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "与往常一样，我们设置了迭代器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
      "  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
     ]
    }
   ],
   "source": [
    "BATCH_SIZE = 64\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# One iterator per split, all using the same batch size and device.\n",
    "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data),\n",
    "    batch_size = BATCH_SIZE,\n",
    "    device = device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将使用上一个notebook中的CNN模型，但是教程中涵盖的任何模型都适用于该数据集。 唯一的区别是现在 `output_dim` 是 $C$ 而不是 $1$（二分类时输出只是单个标量）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"Convolutional sentence classifier that returns one logit per class.\n",
    "\n",
    "    Token indices are embedded, convolved with one Conv2d per filter size,\n",
    "    max-pooled over the sentence, concatenated, passed through dropout and\n",
    "    projected to `output_dim` logits by a linear layer.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: number of rows in the embedding table.\n",
    "        embedding_dim: size of each embedding vector.\n",
    "        n_filters: number of output channels of every convolution.\n",
    "        filter_sizes: iterable of filter heights (in tokens); one convolution each.\n",
    "        output_dim: number of output classes.\n",
    "        dropout: dropout probability applied to the pooled features.\n",
    "        pad_idx: vocabulary index of the padding token; passed to nn.Embedding as\n",
    "            padding_idx so that row receives no gradient updates.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
    "                 dropout, pad_idx):\n",
    "\n",
    "        super().__init__()\n",
    "\n",
    "        # pad_idx was previously accepted but ignored; wire it into the embedding.\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
    "\n",
    "        # One convolution per filter size, each spanning the full embedding width.\n",
    "        self.convs = nn.ModuleList([\n",
    "                                    nn.Conv2d(in_channels = 1, \n",
    "                                              out_channels = n_filters, \n",
    "                                              kernel_size = (fs, embedding_dim)) \n",
    "                                    for fs in filter_sizes\n",
    "                                    ])\n",
    "\n",
    "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
    "\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, text):\n",
    "        \"\"\"Map token indices of shape [sent len, batch size] to logits of shape [batch size, output_dim].\n",
    "\n",
    "        `sent len` must be at least max(filter_sizes); shorter inputs make the\n",
    "        convolutions raise.\n",
    "        \"\"\"\n",
    "\n",
    "        #text = [sent len, batch size]\n",
    "\n",
    "        text = text.permute(1, 0)\n",
    "\n",
    "        #text = [batch size, sent len]\n",
    "\n",
    "        embedded = self.embedding(text)\n",
    "\n",
    "        #embedded = [batch size, sent len, emb dim]\n",
    "\n",
    "        embedded = embedded.unsqueeze(1)\n",
    "\n",
    "        #embedded = [batch size, 1, sent len, emb dim]\n",
    "\n",
    "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
    "\n",
    "        #conv_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
    "\n",
    "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
    "\n",
    "        #pooled_n = [batch size, n_filters]\n",
    "\n",
    "        cat = self.dropout(torch.cat(pooled, dim = 1))\n",
    "\n",
    "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
    "\n",
    "        return self.fc(cat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们定义我们的模型，确保将输出维度： `OUTPUT_DIM` 设置为 $C$。 我们可以通过使用 `LABEL` 词汇的大小轻松获得 $C$，就像我们使用 `TEXT` 词汇的长度来获取输入词汇的大小一样。\n",
    "\n",
    "此数据集中的示例比 IMDb 数据集中的示例小很多，因此我们将使用较小的`filter`大小。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes taken from the vocabularies built above\n",
    "INPUT_DIM = len(TEXT.vocab)\n",
    "OUTPUT_DIM = len(LABEL.vocab)\n",
    "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
    "\n",
    "# Tunable hyper-parameters\n",
    "EMBEDDING_DIM = 100\n",
    "N_FILTERS = 100\n",
    "FILTER_SIZES = [2,3,4]\n",
    "DROPOUT = 0.5\n",
    "\n",
    "model = CNN(vocab_size = INPUT_DIM,\n",
    "            embedding_dim = EMBEDDING_DIM,\n",
    "            n_filters = N_FILTERS,\n",
    "            filter_sizes = FILTER_SIZES,\n",
    "            output_dim = OUTPUT_DIM,\n",
    "            dropout = DROPOUT,\n",
    "            pad_idx = PAD_IDX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "检查参数的数量，我们可以看到我们的参数量约为 IMDb 数据集上 CNN 模型的三分之一。这主要得益于更小的词汇表（embedding 层占了绝大部分参数），较小的`filter`大小只贡献了其中很小的一部分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model has 841,806 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    \"\"\"Return the total number of elements in the parameters of `model` that require gradients.\"\"\"\n",
    "    total = 0\n",
    "    for param in model.parameters():\n",
    "        if param.requires_grad:\n",
    "            total += param.numel()\n",
    "    return total\n",
    "\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "之后，我们将加载我们的预训练embedding。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1117, -0.4966,  0.1631,  ...,  1.2647, -0.2753, -0.1325],\n",
       "        [-0.8555, -0.7208,  1.3755,  ...,  0.0825, -1.1314,  0.3997],\n",
       "        [ 0.1638,  0.6046,  1.0789,  ..., -0.3140,  0.1844,  0.3624],\n",
       "        ...,\n",
       "        [-0.3110, -0.3398,  1.0308,  ...,  0.5317,  0.2836, -0.0640],\n",
       "        [ 0.0091,  0.2810,  0.7356,  ..., -0.7508,  0.8967, -0.7631],\n",
       "        [ 0.5831, -0.2514,  0.4156,  ..., -0.2735, -0.8659, -1.4063]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pretrained_embeddings = TEXT.vocab.vectors\n",
    "\n",
    "# In-place copy into the embedding table; the returned tensor is what the cell displays.\n",
    "embedding_weights = model.embedding.weight.data\n",
    "embedding_weights.copy_(pretrained_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后将用0来初始化未知的权重和padding参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
    "\n",
    "# Zero the embedding rows of the unknown and padding tokens.\n",
    "for special_idx in (UNK_IDX, PAD_IDX):\n",
    "    model.embedding.weight.data[special_idx] = torch.zeros(EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "与之前notebook的另一个不同之处是我们的损失函数。  `BCEWithLogitsLoss` 一般用来做二分类，而 `CrossEntropyLoss`用来做多分类，`CrossEntropyLoss` 对我们的模型输出执行 *softmax* 函数，损失由该函数的输出和标签之间的*交叉熵*给出。\n",
    "\n",
    "一般来说：\n",
    "- 当我们的示例仅属于 $C$ 类之一时，使用 `CrossEntropyLoss`\n",
    "- 当我们的示例仅属于 2 个类（0 和 1）时使用 `BCEWithLogitsLoss`；当每个示例可以同时属于 0 到 $C$ 个类（也称为多标签分类）时，也使用它。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Move the model and the loss to the target device first, so the optimizer is\n",
    "# constructed from the parameters as they live on that device.\n",
    "model = model.to(device)\n",
    "criterion = criterion.to(device)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "之前，我们有一个函数可以计算二进制标签情况下的准确度，我们说如果值超过 0.5，那么我们会假设它是正的。 在我们有超过 2 个类的情况下，我们的模型输出一个 $C$ 维向量，其中每个元素的值是示例属于该类的置信度。\n",
    "\n",
    "例如，在我们的标签中，我们有：'HUM' = 0、'ENTY' = 1、'DESC' = 2、'NUM' = 3、'LOC' = 4 和 'ABBR' = 5。如果我们模型的输出是这样的：**[5.1, 0.3, 0.1, 2.1, 0.2, 0.6]**，这意味着该模型确信该示例属于第 0 类（关于人类的问题），并且略微相信该示例属于第 3 类（关于数字的问题）。\n",
    "\n",
    "我们通过执行 `argmax` 来获取批次中每个元素的预测最大值的索引，然后计算它与实际标签相等的次数来计算准确度。 然后我们对整个批次进行平均。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def categorical_accuracy(preds, y):\n",
    "    \"\"\"\n",
    "    Fraction of rows in `preds` whose highest-scoring class equals the label in `y`,\n",
    "    e.g. 8 correct out of 10 gives 0.8 (not 8). Returned as a 0-dim float tensor.\n",
    "    \"\"\"\n",
    "    predicted = preds.argmax(dim = 1)\n",
    "    n_correct = (predicted == y.view_as(predicted)).sum()\n",
    "    return n_correct.float() / y.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练循环与之前类似，`CrossEntropyLoss`期望输入数据为 **[batch size, n classes]** ，标签为 **[batch size]** 。\n",
    "\n",
    "标签需要是 `LongTensor` 类型的数据；它默认就是这种类型，因为我们没有像以前那样将 `dtype` 设置为 `FloatTensor`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion):\n",
    "    \"\"\"Run one optimisation pass over `iterator`; return (mean loss, mean accuracy).\n",
    "\n",
    "    Every batch must expose `.text` and `.label`. Both means are weighted by\n",
    "    batch size so that a smaller final batch does not skew the epoch figures.\n",
    "    Assumes `criterion` returns the mean loss of the batch.\n",
    "    \"\"\"\n",
    "\n",
    "    total_loss = 0.0\n",
    "    total_acc = 0.0\n",
    "    n_examples = 0\n",
    "\n",
    "    model.train()\n",
    "\n",
    "    for batch in iterator:\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        predictions = model(batch.text)\n",
    "\n",
    "        loss = criterion(predictions, batch.label)\n",
    "\n",
    "        acc = categorical_accuracy(predictions, batch.label)\n",
    "\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "        # Turn the per-batch means back into sums before accumulating.\n",
    "        batch_size = batch.label.shape[0]\n",
    "        total_loss += loss.item() * batch_size\n",
    "        total_acc += acc.item() * batch_size\n",
    "        n_examples += batch_size\n",
    "\n",
    "    return total_loss / n_examples, total_acc / n_examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "评估循环同样与之前类似。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "    \"\"\"Score `model` on `iterator` without gradients; return (mean loss, mean accuracy).\n",
    "\n",
    "    Every batch must expose `.text` and `.label`. Both means are weighted by\n",
    "    batch size so that a smaller final batch does not skew the result.\n",
    "    Assumes `criterion` returns the mean loss of the batch.\n",
    "    \"\"\"\n",
    "\n",
    "    total_loss = 0.0\n",
    "    total_acc = 0.0\n",
    "    n_examples = 0\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "\n",
    "        for batch in iterator:\n",
    "\n",
    "            predictions = model(batch.text)\n",
    "\n",
    "            loss = criterion(predictions, batch.label)\n",
    "\n",
    "            acc = categorical_accuracy(predictions, batch.label)\n",
    "\n",
    "            # Turn the per-batch means back into sums before accumulating.\n",
    "            batch_size = batch.label.shape[0]\n",
    "            total_loss += loss.item() * batch_size\n",
    "            total_acc += acc.item() * batch_size\n",
    "            n_examples += batch_size\n",
    "\n",
    "    return total_loss / n_examples, total_acc / n_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def epoch_time(start_time, end_time):\n",
    "    \"\"\"Split the interval between two timestamps (in seconds) into whole (minutes, seconds).\"\"\"\n",
    "    elapsed = end_time - start_time\n",
    "    whole_minutes = int(elapsed / 60)\n",
    "    leftover_seconds = int(elapsed - (whole_minutes * 60))\n",
    "    return whole_minutes, leftover_seconds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
      "  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 1.312 | Train Acc: 47.11%\n",
      "\t Val. Loss: 0.947 |  Val. Acc: 66.41%\n",
      "Epoch: 02 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 0.870 | Train Acc: 69.18%\n",
      "\t Val. Loss: 0.741 |  Val. Acc: 74.14%\n",
      "Epoch: 03 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 0.675 | Train Acc: 76.32%\n",
      "\t Val. Loss: 0.621 |  Val. Acc: 78.49%\n",
      "Epoch: 04 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 0.506 | Train Acc: 83.97%\n",
      "\t Val. Loss: 0.547 |  Val. Acc: 80.32%\n",
      "Epoch: 05 | Epoch Time: 0m 0s\n",
      "\tTrain Loss: 0.373 | Train Acc: 88.23%\n",
      "\t Val. Loss: 0.487 |  Val. Acc: 82.92%\n"
     ]
    }
   ],
   "source": [
    "N_EPOCHS = 5\n",
    "\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "\n",
    "    # Time one full train + validation pass.\n",
    "    start_time = time.time()\n",
    "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
    "    end_time = time.time()\n",
    "\n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "\n",
    "    # Checkpoint whenever the validation loss reaches a new minimum.\n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'tut5-model.pt')\n",
    "\n",
    "    epoch_report = [\n",
    "        f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s',\n",
    "        f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%',\n",
    "        f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%',\n",
    "    ]\n",
    "    print('\\n'.join(epoch_report))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后，在测试集上运行我们的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.415 | Test Acc: 86.07%\n"
     ]
    }
   ],
   "source": [
    "# map_location keeps the checkpoint loadable when the device it was saved from is unavailable.\n",
    "model.load_state_dict(torch.load('tut5-model.pt', map_location = device))\n",
    "\n",
    "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
    "\n",
    "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "类似于我们创建一个函数来预测任何给定句子的情绪，我们现在可以创建一个函数来预测给定问题的类别。\n",
    "\n",
    "这里唯一的区别是，我们没有使用 sigmoid 函数将输入压缩在 0 和 1 之间，而是使用 `argmax` 来获得最高的预测类索引。 然后我们使用这个索引和标签 vocab 来获得可读的标签string。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import spacy\n",
    "nlp = spacy.load('en_core_web_sm')\n",
    "\n",
    "def predict_class(model, sentence, min_len = 4):\n",
    "    \"\"\"Return the index of the highest-scoring class for a raw question string.\n",
    "\n",
    "    The sentence is tokenized with the module-level spaCy `nlp`, right-padded\n",
    "    with the TEXT pad token up to `min_len` tokens, and mapped to indices\n",
    "    through TEXT.vocab before being fed to `model` as a batch of one.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
    "    if len(tokenized) < min_len:\n",
    "        # Use the field's own pad token rather than a hard-coded string.\n",
    "        tokenized += [TEXT.pad_token] * (min_len - len(tokenized))\n",
    "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
    "    # [sent len] -> [sent len, 1]: a batch holding this single sentence\n",
    "    tensor = torch.LongTensor(indexed).to(device).unsqueeze(1)\n",
    "    # Inference only, so skip gradient tracking.\n",
    "    with torch.no_grad():\n",
    "        preds = model(tensor)\n",
    "    return preds.argmax(dim = 1).item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，让我们在几个不同的问题上尝试一下……"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 0 = HUM\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(model, \"Who is Keyser Söze?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 3 = NUM\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(model, \"How many minutes are in six hundred and eighteen hours?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 4 = LOC\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(model, \"What continent is Bulgaria in?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 5 = ABBR\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(model, \"What does WYSIWYG stand for?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
